
; (c) Daniel Llorens - 2012-2013
; Tests for (ploy reduce)

; This library is free software; you can redistribute it and/or modify it under
; the terms of the GNU General Public License as published by the Free
; Software Foundation; either version 3 of the License, or (at your option) any
; later version.

(import (srfi srfi-1) (srfi srfi-9) (srfi srfi-26) (ploy basic) (ploy test)
        (ploy ploy) (ploy reduce) (ploy as-array))

; -----------------------------
; reductions: over, folda/foldb, dot.
; -----------------------------

; null cases.

; Folding an empty rank-1 array must hand the seed back unchanged.
(let ((seed 3)
      (nothing #()))
  (T seed
     (folda + seed nothing)
     (foldb + seed nothing)))
; Fold step used by the null-case tests: keeps the larger of the running value
; and the sum of the elements of the current item.
; NOTE(review): '() is presumably the output shape and 0 1 the ranks of the two
; arguments (accumulator, item) -- confirm against `verb' in (ploy basic).
(define v-max-sum
  (let ((max-with-sum
         (lambda (best item)
           (max best (apply + (vector->list item))))))
    (verb max-with-sum '() 0 1)))

; One item, and that item is empty: its sum is 0, so the result is (max -1 0).
(let ((seed -1)
      (one-empty-item #2(())))
  (T 0
     (folda v-max-sum seed one-empty-item)
     (foldb v-max-sum seed one-empty-item)))
; No items at all: the seed must come back untouched.
(let ((seed -1)
      (no-items #2()))
  (T seed
     (folda v-max-sum seed no-items)
     (foldb v-max-sum seed no-items)))

; These examples are taken from fig. 6.4 ff. of Rich2006.
; @TODO There's a reason why m +/, d $ work with items. Think about that.
; Sum reductions, one per reduction primitive: +/ goes through `over',
; +/a through `folda' (seed 0) and +/b through `foldb' (seed 0).
; NOTE(review): #f and '_ are presumably the output shape and an unbounded
; argument rank -- confirm against `verb' in (ploy basic).
(define +/
  (verb (lambda (x) (over (lambda (p q) (ply + p q)) x))
        #f '_))
(define +/a
  (verb (lambda (x) (folda + 0 x))
        #f '_))
(define +/b
  (verb (lambda (x) (foldb + 0 x))
        #f '_))

; Summing along the last axis of (i. 2 3 4). For each of +/, +/a and +/b the
; same reduction is spelled three ways: rank 1, rank -2, and rank 1 nested
; under rank 2. All nine results must equal the expected 2x3 array.
(let ((x (i. 2 3 4)))
  (apply T
         #2((6 22 38) (54 70 86))
         (map (lambda (sum-verb) (ply sum-verb x))
              (append-map (lambda (v)
                            (list (w/rank v 1)
                                  (w/rank v -2)
                                  (w/rank (w/rank v 1) 2)))
                          (list +/ +/a +/b)))))

; Summing along the middle axis of (i. 2 3 4). For each of +/, +/a and +/b the
; reduction is requested both as rank 2 and as rank -1; all six results must
; equal the expected 2x4 array.
(let ((x (i. 2 3 4)))
  (apply T
         #2((12 15 18 21) (48 51 54 57))
         (map (lambda (sum-verb) (ply sum-verb x))
              (append-map (lambda (v)
                            (list (w/rank v 2)
                                  (w/rank v -1)))
                          (list +/ +/a +/b)))))


; Summing along the first axis of (i. 2 3 4). For each of +/, +/a and +/b the
; verb is applied both with explicit rank 3 and with its own default rank;
; all six results must equal the expected 3x4 array.
(let ((x (i. 2 3 4)))
  (apply T
         #2((12 14 16 18) (20 22 24 26) (28 30 32 34))
         (map (lambda (sum-verb) (ply sum-verb x))
              (append-map (lambda (v)
                            (list (w/rank v 3)
                                  v))
                          (list +/ +/a +/b)))))

; other cases. @TODO Also benchmark :(
; Random test data built with dimensions (10000 2), kept both as an array (A)
; and as the nested list produced by array->list (B).
(define A (make-random-array '(10000 2) #:type #t))
(define B (array->list A))
; Maximum reductions, one per reduction primitive: max/ goes through `over',
; max/a through `folda' and max/b through `foldb', the latter two seeded with
; -inf.0.
(define max/
  (verb (lambda (x) (over (lambda (p q) (ply max p q)) x))
        #f '_))
(define max/a
  (verb (lambda (x) (folda max -inf.0 x))
        #f '_))
(define max/b
  (verb (lambda (x) (foldb max -inf.0 x))
        #f '_))

; Reference: element-wise max accumulated across all rows of B. The three
; max verbs applied to A must reproduce it exactly (eps 0.0).
(let ((column-maxima (fold (lambda (row best) (map max row best))
                           (car B)
                           (cdr B))))
  (T-eps 0.0
         (list->array 1 column-maxima)
         (ply max/ A)
         (ply max/a A)
         (ply max/b A)))

; Reference: the max of each row of B taken on its own. The three max verbs
; applied at rank -1 to A must reproduce it exactly (eps 0.0).
(let ((row-maxima (map (lambda (row) (apply max row)) B)))
  (T-eps 0.0
         (list->array 1 row-maxima)
         (ply (w/rank max/ -1) A)
         (ply (w/rank max/a -1) A)
         (ply (w/rank max/b -1) A)))

; @TODO folda uses ply on each op application, so on each scalar. Very slow.
; @TODO foldb suffers from apply map from.
; Sum of the squares of the elements of (i. 10000), computed four ways:
; folda, foldb, an explicit index loop and array-for-each with a mutable
; accumulator. All four must agree exactly (eps 0.0).
(define (_sqr x) (* x x))
(define a (i. 10000))
(let ((add-square (lambda (acc x) (+ acc (_sqr x)))))
  (T-eps 0.0
         (folda add-square 0 a)
         (foldb add-square 0 a)
         ; plain index loop over 0 .. (tally a) - 1
         (let ((n (tally a)))
           (do ((i 0 (+ 1 i))
                (sum 0 (+ sum (_sqr (array-ref a i)))))
               ((= i n) sum)))
         ; side-effecting traversal
         (let ((sum 0))
           (array-for-each (lambda (x) (set! sum (+ sum (* x x)))) a)
           sum)))

; exercise folda / foldb with w/rank-using op. Args are rank 1 & 2.
; folda-based counterpart of `dot', used to drive folda with a w/rank-wrapped
; op. The parameters + and * deliberately shadow the global operators: they
; are the sum and product handed to the private _madd of (ploy reduce).
; The combined op is ranked '_ 0 1 and the fold is seeded with 0.
(define* (xdota + * A B)
  (let* ((madd ((@@ (ploy reduce) _madd) + *))
         (step (w/rank madd '_ 0 1)))
    (folda step 0 A B)))

; foldb-based counterpart of `dot', used to drive foldb with a w/rank-wrapped
; op. The parameters + and * deliberately shadow the global operators: they
; are the sum and product handed to the private _madd of (ploy reduce).
; The combined op is ranked '_ 0 1 and the fold is seeded with 0.
(define* (xdotb + * A B)
  (let* ((madd ((@@ (ploy reduce) _madd) + *))
         (step (w/rank madd '_ 0 1)))
    (foldb step 0 A B)))

; dot, xdota and xdotb must agree exactly (eps 0.0) on a rank-1 by rank-2
; product.
(let ((v (i. 100))
      (m (i. 100 20)))
  (T-eps 0.0
         (dot + * v m)
         (xdota + * v m)
         (xdotb + * v m)
         ;; (blas-dgemv (as-array (i. 100 20) #:type 'f64) (as-array (i. 100) #:type 'f64) 1. 'transpose)
         ))

; @TODO Think about these:
;; (folda (w/rank + 0 1) 0 (i. 10 3)) ; works, it shouldn't (first op is (+ 0 #(...)), but next?)
;; (folda (w/rank + 1 1) 0 (i. 10 3)) ; doesn't work, probably ok (0 is not rank 1).
;; (folda (w/rank + 1 1) #(0 0 0) (i. 10 3)) ; works, properly
;; (folda (w/rank + 0 0) #(0 0 0) (i. 10 3)) ; works, properly
;; (folda + #(0 0 0) (i. 10 3)) ; works, properly
;; (ply (w/rank (verb (cut folda + 0 <>) #f '_) 1) (i. 10 3)) ; works, properly

; --------------------------------
; dot and dot variants
; --------------------------------

; Small literal fixtures for the dot tests below.
; Note: A, B and a rebind top-level names that the reduction tests earlier in
; this file already defined (random array, its list form, and (i. 10000)).
(define A #2((1 2) (3 4)))       ; 2x2 matrix
(define B #2((1 3) (2 4)))       ; 2x2 matrix, the transpose of A
(define C #2((1 2 3) (4 5 6)))   ; 2x3 matrix
(define a #1(10 20))             ; length-2 vector
(define b #1(10 20 30))          ; length-3 vector

; Hand-computed inner products. Unlike the tests above, the expected value is
; passed last here.
; vector . vector: 10*10 + 20*20 = 500
(T (dot + * a a) 500)
; matrix . vector: (1*10 + 2*20, 3*10 + 4*20) = (50 110)
(T (dot + * A a) #(50 110))
; matrix . matrix: ((1*1 + 2*2, 1*3 + 2*4) (3*1 + 4*2, 3*3 + 4*4)) = ((5 11) (11 25))
(T (dot + * A B) #2((5 11) (11 25)))
; 2x3 matrix . length-3 vector: (10 + 40 + 90, 40 + 100 + 180) = (140 320)
(T (dot + * C b) #(140 320))

; outer product versions (k i ... j ...)
; @TODO Find the transformation between all of these.
; @TODO Remove temporary / map / from overhead.

;; A: i ... k
;; B: k j ...

;; A must be transposed to place the fold axis first:
;; a: k i ...
;; b: k j ...

;; so w/rank layout under the fold is
;; c: i ... | j ...  ->  cell rank from j ...
;; a: i ... |        ->  cell rank 0
;; b:       | j ...  ->  cell rank from j ...

; Build an outer-product style dot on top of FOLD-OP (a typed fold such as
; folda/t or foldb/t). The returned procedure takes the sum op, the product
; op, the operands A and B and an optional #:type; when #:type is omitted the
; type is taken from A with array-type*.
(define (make-kij fold-op)
  (lambda* (+_ *_ A B #:key type)
    (let* ((out-type (or type (array-type* A)))
           ; result dimensions: those of A minus its last axis, followed by
           ; those of B minus its first axis.
           (out-dims (append (drop-right ($ A) 1) (drop ($ B) 1)))
           ; NOTE(review): presumably a 0-filled array of shape out-dims --
           ; confirm against `reshape'.
           (seed (apply reshape 0 out-dims))
           (madd ((@@ (ploy reduce) _madd) +_ *_))
           ; accumulator cells of rank (rank B) - 1, scalar cells from A,
           ; whole items from B.
           (step (w/rank madd (- (rank B) 1) 0 '_)))
      ; A is passed with its last axis rolled to the front so that the fold
      ; axis comes first in both operands.
      (fold-op out-type step seed (rollaxis A -1 0) B))))

; The two kij variants, one per typed fold.
(define kij-dota (make-kij folda/t))
(define kij-dotb (make-kij foldb/t))

; Both variants must match `dot' on a rank-2 by rank-3 product.
(let ((lhs (i. 20 30))
      (rhs (i. 30 40 10)))
  (T (dot + * lhs rhs)
     (kij-dota + * lhs rhs)
     (kij-dotb + * lhs rhs)))
